######################################################################################88
import argparse
from os.path import exists
import sys
import os

# Build the command-line interface.  The long epilog below doubles as the
# usage documentation printed by --help (including two worked example command
# lines); it is an f-string so {sys.argv[0]} interpolates the actual script
# name, and RawDescriptionHelpFormatter preserves its formatting verbatim.
parser = argparse.ArgumentParser(
    description='Run the CoNGA clonotype neighbor-graph analysis pipeline.',
    epilog = f'''
==============================================================================
Run the CoNGA clonotype neighbor-graph analysis pipeline.

see https://github.com/phbradley/conga for more information.

Requires that a tsv formatted "clones_file" (and barcode mapping file)
has already been generated by the script setup_10x_for_conga.py

Minimal command line arguments would be:

    --gex_data
    --gex_data_type
    --clones_file
    --organism
    --outfile_prefix (string that will be prepended to all outputs)

    and flags to trigger the desired analyses, for example

    --graph_vs_graph
    --graph_vs_features
    --tcr_clumping
    --match_to_tcr_database

    --all (runs all the major analyses)

You can also provide all commandline arguments in a yaml-formatted file
    and pass that in with

    --config

See above for argument descriptions.

Examples:

    python3 {sys.argv[0]} --graph_vs_graph --graph_vs_features --gex_data vdj_v1_hs_pbmc3_5gex_filtered_gene_bc_matrices_h5.h5 --gex_data_type 10x_h5 --clones_file vdj_v1_hs_pbmc3_t_filtered_contig_annotations_tcrdist_clones.tsv --organism human --outfile_prefix tmp_hs_pbmc3

    python3 {sys.argv[0]} --all --gex_data mouse_cd4_filtered_feature_bc_matrix/ --gex_data_type 10x_mtx --clones_file mouse_cd4_filtered_contig_annotations_tcrdist_clones.tsv --organism mouse --outfile_prefix tmp_mouse_cd4
    ''',
    formatter_class=argparse.RawDescriptionHelpFormatter,
    )

# core args
parser.add_argument('--version', action='version', version='CoNGA version 0.1.1')
parser.add_argument('--config', help="configuration file *.yml", type=str)
parser.add_argument('--gex_data',
                    help='Input file with the single-cell gene expression data')
parser.add_argument('--gex_data_type',
                    choices=['h5ad', '10x_mtx', '10x_h5'],
                    help='Format of the GEX input file. Options are "10x_mtx"'
                    ' for a 10x directory with .mtx and associated files;'
                    ' "10x_h5" for a 10x HDF5 formatted file; and "h5ad"'
                    ' for a scanpy formatted hdf5 file')
parser.add_argument('--clones_file',
                    help='tsv-formatted clonotype file generated by'
                    ' setup_10x_for_conga.py (for example)')
# '_gd'/'_ig' variants presumably select gamma-delta / immunoglobulin
# receptor data -- confirm against the conga documentation
parser.add_argument('--organism',
                    choices=['mouse', 'human', 'mouse_gd', 'human_gd',
                             'human_ig'])
parser.add_argument('--nbr_fracs', type=float, nargs='*', default=[0.01,0.1],
                    help='Size of neighborhoods to use in building K'
                    ' nearest neighbor graphs, expressed as a fraction'
                    ' of the total dataset size in clonotypes. Default values'
                    ' are 0.01 and 0.1')
parser.add_argument('--outfile_prefix',
                    help='string that will be prepended to all output files'
                    ' and images')


# the main modes of operation
# (store_true on/off switches; --all is expanded into the standard set of
#  these flags after argument parsing, see below)
parser.add_argument('--all', action='store_true',
                    help='Run all standard analyses')
parser.add_argument('--graph_vs_graph', action='store_true')
parser.add_argument('--graph_vs_graph_stats', action='store_true')
parser.add_argument('--graph_vs_features', action='store_true')
parser.add_argument('--match_to_tcr_database', action='store_true',
                    help='Find significant matches to paired tcrs in the'
                    ' database specified by --tcr_database_tsvfile (default'
                    ' is the dataset in'
                    ' conga/data/new_paired_tcr_db_for_matching_nr.tsv')
parser.add_argument('--tcr_database_tsvfile',
                    help='Must have columns va cdr3a vb cdr3b, minimally;'
                    ' with imgt-recognized allele names; default is'
                    ' conga/data/new_paired_tcr_db_for_matching_nr.tsv')
parser.add_argument('--tcr_clumping', action='store_true')
parser.add_argument('--find_hotspot_features', action='store_true')


## minor analysis modes
parser.add_argument('--make_tcrdist_trees', action='store_true')
parser.add_argument('--cluster_vs_cluster', action='store_true')
parser.add_argument('--find_gex_cluster_degs', action='store_true')
parser.add_argument('--find_batch_biases', action='store_true')
parser.add_argument('--calc_clone_pmhc_pvals', action='store_true')
parser.add_argument('--find_pmhc_nbrhood_overlaps', action='store_true')
parser.add_argument('--find_distance_correlations', action='store_true')
parser.add_argument('--find_hotspot_nbrhoods', action='store_true')
parser.add_argument('--plot_cluster_gene_compositions', action='store_true')
parser.add_argument('--analyze_CD4_CD8', action='store_true')
parser.add_argument('--analyze_proteins', action='store_true')
parser.add_argument('--analyze_special_genes', action='store_true')

# options for subsetting the data
parser.add_argument('--exclude_gex_clusters', type=int, nargs='*')
parser.add_argument('--exclude_mait_and_inkt_cells', action='store_true')
parser.add_argument('--subset_to_CD4', action='store_true')
parser.add_argument('--subset_to_CD8', action='store_true')
parser.add_argument('--subset_to_CD4_cells', action='store_true') # tmp testing
parser.add_argument('--subset_to_CD8_cells', action='store_true') # tmp testing
parser.add_argument('--bad_barcodes_file')
parser.add_argument('--exclude_vgene_strings', type=str, nargs='*')

# options for configuring parameters
parser.add_argument('--average_clone_gex', action='store_true',
                    help='Average GEX PCs over all members of a clonotype '
                    'instead of picking a single representative cell.')
parser.add_argument('--min_cluster_size', type=int, default=5)
parser.add_argument('--min_cluster_size_fraction', type=float, default=0.001)
parser.add_argument('--min_cluster_size_for_tcr_clumping_logos', type=int,
                    default=3)
# if set, nbr_fracs are adjusted upward after loading so each neighborhood
# has at least this many members (see the nbr_frac adjustment code below)
parser.add_argument('--min_nbrhood_size', type=int)
parser.add_argument('--pvalue_threshold_for_db_matching', type=float,
                    default=1.0)
parser.add_argument('--pvalue_threshold_for_tcr_clumping', type=float,
                    default=1.0)
parser.add_argument('--num_random_samples_for_tcr_clumping', type=int,
                    default=50000)
parser.add_argument('--num_random_samples_for_tcr_matching', type=int,
                    default=50000)
parser.add_argument('--clustering_method', choices=['louvain','leiden'])
parser.add_argument('--clustering_resolution', type=float, default = 1.0)
parser.add_argument('--make_hotspot_raw_feature_plots', action='store_true',
                    help='The default is just to plot the nbrhood-averaged'
                    ' values')
parser.add_argument('--analyze_junctions', action='store_true')
parser.add_argument('--radii_for_tcr_clumping', type=int, nargs='*')
# when set, filter_and_scale is given an outfile prefix for QC plots
parser.add_argument('--qc_plots', action='store_true')
parser.add_argument('--max_clones_for_clustermaps', type=int, default=20000,
                    help='Currently the clustermapping code computes the full'
                    ' pairwise matrix of distances for hierarchical clustering.'
                    ' This can get slow and memory intensive, so limit the'
                    ' dataset size for these plots.')

# configure the logo plots and some other plots
parser.add_argument('--skip_gex_header', action='store_true')
parser.add_argument('--skip_gex_header_raw', action='store_true')
parser.add_argument('--skip_gex_header_nbrZ', action='store_true')
parser.add_argument('--skip_tcr_scores_in_gex_header', action='store_true')
parser.add_argument('--include_alphadist_in_tcr_feature_logos',
                    action='store_true')
parser.add_argument('--show_pmhc_info_in_logos', action='store_true',
                    help='[DEV] This only works if you have properly formatted '
                    'pMHC binding data in your counts matrix.')
parser.add_argument('--gex_header_tcr_score_names', type=str, nargs='*')
parser.add_argument('--gex_logo_genes', type=str, nargs='*')
parser.add_argument('--gex_header_genes', type=str, nargs='*')
# defaults to the full list of tcr scorenames if not given (see below)
parser.add_argument('--gex_nbrhood_tcr_score_names', type=str, nargs='*')
parser.add_argument('--dont_show_lit_matches_in_logos', action='store_true')
parser.add_argument('--short_clustermaps', action='store_true')


# preprocessing options
parser.add_argument('--max_genes_per_cell', type=int)
parser.add_argument('--min_genes_per_cell', type=int)
parser.add_argument('--max_percent_mito', type=float)

# if your input AnnData file has integer-valued columns defining
#  batch/tissue/donor/etc you can pass the column names with this option
#  and it will add 'batch' information to the logo plots
parser.add_argument('--batch_keys', type=str, nargs='*')

# options to use exact tcrdist neighbors/umap/clusters rather than KPCA
# (--no_kpca implies the three flags below; expanded after arg parsing)
parser.add_argument('--no_kpca', action='store_true')
parser.add_argument('--use_exact_tcrdist_nbrs', action='store_true',
                    help='The default is to use the nbrs defined by'
                    ' euclidean distances in the tcrdist kernel pc space.'
                    ' This flag will force a re-computation of all the tcrdist'
                    ' distances')
parser.add_argument('--use_tcrdist_umap', action='store_true')
parser.add_argument('--use_tcrdist_clusters', action='store_true')


# some random deprecated/under-development arguments
parser.add_argument('--intra_cluster_tcr_clumping', action='store_true')
parser.add_argument('--make_hotspot_nbrhood_logos', action='store_true')
parser.add_argument('--include_protein_features', action='store_true')
parser.add_argument('--verbose_nbrs', action='store_true')
parser.add_argument('--tenx_agbt', action='store_true')
parser.add_argument('--exclude_batch_keys_for_biases', type=str, nargs='*')
parser.add_argument('--shuffle_tcr_kpcs', action='store_true',
                    help='shuffle the TCR kpcs to test for FDR')
parser.add_argument('--shuffle_gex_nbrs', action='store_true')
parser.add_argument('--suffix_for_non_gene_features', type=str)
parser.add_argument('--min_cluster_size_for_batch_bias_logos', type=int,
                    default=5)
parser.add_argument('--make_clone_plots', action='store_true')
parser.add_argument('--write_proj_info', action='store_true')
parser.add_argument('--filter_ribo_norm_low_cells', action='store_true')


# kernel pca args, for I/O or if we are recomputing
parser.add_argument('--kpca_file',
                    help='Pass filename if using a non-standard location'
                    ' (ie not clones_file[:-4]+\'_AB.dist_50_kpcs\')')
parser.add_argument('--restart',
                    help='Name of a scanpy h5ad file to restart from; skips'
                    ' preprocessing, clustering, UMAP, etc. Could be the'
                    ' *_final.h5ad file generated at the end of a previous'
                    ' conga run.')
parser.add_argument('--rerun_kpca', action='store_true')
parser.add_argument('--kpca_kernel',
                    help='only used if rerun_kpca is True; if not provided'
                    ' will use classic kernel')
parser.add_argument('--kpca_gaussian_kernel_sdev', default=100.0, type=float,
                    help='only used if rerun_kpca and kpca_kernel=\'gaussian\'')
parser.add_argument('--kpca_default_kernel_Dmax', type=float,
                    help='only used if rerun_kpca and kpca_kernel==None')

# option to save a checkpoint file for looking at umaps, clusters during analysis
parser.add_argument('--checkpoint', action='store_true',
                    help='Save a scanpy h5ad checkpoint file after'
                    ' preprocessing')

args = parser.parse_args()

# with no command line arguments at all, just show the help text and exit
if len(sys.argv)==1:
    parser.print_help()
    sys.exit()

# update args specified in yml file
if args.config is not None:
    import yaml
    # explicit check instead of an assert: asserts vanish under python -O,
    # and a missing user-supplied file deserves a readable error
    if not exists(args.config):
        sys.stderr.write(f'ERROR: --config file not found: {args.config}\n')
        sys.exit(1)
    # context manager so the config file handle is closed promptly
    with open(args.config) as f:
        yml_args = yaml.load(f, Loader=yaml.FullLoader)
    for k, v in yml_args.items():
        if k in args.__dict__:
            args.__dict__[k] = v
        else:
            sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))

if args.outfile_prefix is None:
    print('Prefix for output files not specified. Add to --config file or specify with --outfile_prefix')
    sys.exit(1)

# do the imports now since they are so freakin slow
from collections import Counter, OrderedDict
import time
# in order to import conga package: add the directory two levels above this
# script to sys.path so 'import conga' works without installation
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import matplotlib
matplotlib.use('Agg') # for remote calcs (non-interactive backend, no display)
import matplotlib.pyplot as plt
import conga
from conga import util
import scanpy as sc
import scanpy.neighbors
from sklearn.metrics import pairwise_distances
import numpy as np
import pandas as pd
from pathlib import Path

# record the start so later stages can report elapsed time
start_time = time.time()


#############################################################################
###
### misc arg parsing
###
#############################################################################

# fall back to the full set of tcr score names when the user didn't pick any
if args.gex_nbrhood_tcr_score_names is None:
    args.gex_nbrhood_tcr_score_names = list(
        conga.tcr_scoring.all_tcr_scorenames)

if args.all:
    # --all is shorthand for switching on each of the major analysis modes
    all_modes = [
        'graph_vs_graph',
        'graph_vs_graph_stats',
        'graph_vs_features',
        'cluster_vs_cluster',
        'find_hotspot_features',
        'find_gex_cluster_degs',
        'tcr_clumping',
        'match_to_tcr_database',
        'make_tcrdist_trees',
    ]
    for mode in all_modes:
        print(f'--all implies --{mode} ==> Running {mode} analysis.')
        setattr(args, mode, True)

if args.no_kpca:
    # without the kernel PCA everything in tcr space has to fall back on
    # exact tcrdist computations
    print('--no_kpca implies --use_exact_tcrdist_nbrs and --use_tcrdist_umap '
          'and --use_tcrdist_clusters')
    print('setting those flags now')
    args.use_exact_tcrdist_nbrs = True
    args.use_tcrdist_umap = True
    args.use_tcrdist_clusters = True


## check consistency of args
# if args.find_pmhc_nbrhood_overlaps or args.calc_clone_pmhc_pvals:
#     # we need pmhc info for these analyses; right now that's restricted
#     # to the 10x AGBT dataset format
#     assert 'pmhc_var_names' in adata.uns or args.tenx_agbt

if args.batch_keys:
    assert args.gex_data_type == 'h5ad' # need the info already in the obs dict

if args.restart: # these are incompatible with restarting
     assert not (args.calc_clone_pmhc_pvals or
                 args.bad_barcodes_file or
                 args.filter_ribo_norm_low_cells or
                 args.exclude_vgene_strings or
                 args.subset_to_CD4_cells or
                 args.subset_to_CD8_cells or
                 #args.shuffle_tcr_kpcs or
                 args.rerun_kpca )

# open the run log; it stays open for the lifetime of the script and is
# written to by the various analysis stages below
logfile = args.outfile_prefix+'_log.txt'
outlog = open(logfile, 'w')
outlog.write('sys.argv: {}\n'.format(' '.join(sys.argv)))
try: # scanpy changed and this doesn't seem to work anymore
    sc.logging.print_versions()
except Exception:
    # BUGFIX: was a bare "except:", which would also swallow
    # KeyboardInterrupt/SystemExit
    sc.logging.print_header()
hostname = os.popen('hostname').readlines()[0][:-1]
outlog.write('hostname: {}\n'.format(hostname))

#############################################################################
###
### I/O --- reading GEX/TCR or restarting from a previous run
###
#############################################################################


if args.restart is None: ################################## load GEX/TCR data
    # the kpca file is only dispensable when every tcr-space computation
    # (nbrs, umap, clusters) will use exact tcrdist distances instead
    allow_missing_kpca_file = (
        args.use_exact_tcrdist_nbrs and
        args.use_tcrdist_umap and
        args.use_tcrdist_clusters
        )

    assert exists(args.gex_data)

    # adding possibility of getting tcr info from adata
    if args.clones_file is None:
        print('WARNING --clones_file flag is missing,'
              ' h5ad file must already contain TCR info')

    ## load the dataset
    # optionally (re)compute the tcrdist kernel PCs from the clones file
    # before loading, writing them to args.kpca_file so read_dataset picks
    # them up
    if args.rerun_kpca and args.clones_file is not None:
        if args.kpca_file is None:
            args.kpca_file = args.outfile_prefix+'_rerun_tcrdist_kpca.txt'
        else:
            print('WARNING:: overwriting', args.kpca_file,
                  'since --rerun_kpca is True')
        conga.preprocess.make_tcrdist_kernel_pcs_file_from_clones_file(
            args.clones_file,
            args.organism,
            kernel=args.kpca_kernel,
            outfile=args.kpca_file,
            gaussian_kernel_sdev = args.kpca_gaussian_kernel_sdev,
            force_Dmax = args.kpca_default_kernel_Dmax,
        )

    adata = conga.preprocess.read_dataset(
        args.gex_data,
        args.gex_data_type,
        args.clones_file,
        kpca_file = args.kpca_file, # default is None
        allow_missing_kpca_file = allow_missing_kpca_file,
        gex_only = False,
        suffix_for_non_gene_features = args.suffix_for_non_gene_features,
    )

    if args.rerun_kpca and args.clones_file is None:
        # do this (compute tcrdist kernel pcs) now rather than before
        #   dataset loading if we don't have a clones file
        tcrs = conga.preprocess.retrieve_tcrs_from_adata(adata)
        kpcs = conga.preprocess.make_tcrdist_kernel_pcs_file_from_clones_file(
            None,
            args.organism,
            kernel=args.kpca_kernel,
            outfile=args.kpca_file,
            gaussian_kernel_sdev=args.kpca_gaussian_kernel_sdev,
            force_Dmax=args.kpca_default_kernel_Dmax,
            tcrs=tcrs,
            return_pcs = True,
        )
        adata.obsm['X_pca_tcr'] = kpcs


    assert args.organism
    adata.uns['organism'] = args.organism
    assert 'organism' in adata.uns_keys()
    if args.batch_keys:
        # record the batch columns and sanity-check that each one is an
        # integer-valued column of adata.obs
        adata.uns['batch_keys'] = args.batch_keys
        for k in args.batch_keys:
            assert k in adata.obs_keys()
            # confirm integer-value
            vals = np.array(adata.obs[k]).astype(int)
            counts = Counter(vals)
            expected_choices = np.max(vals)+1
            observed_choices = len(counts.keys())
            print(f'read batch info for key {k} with {expected_choices}'
                  f' possible and {observed_choices} observed choices')
            adata.obs[k] = vals
    elif 'batch_keys' in adata.uns:
        # batch_keys carried over inside the input h5ad; drop any that no
        # longer correspond to columns of adata.obs
        old_batch_keys = list(adata.uns['batch_keys'])
        if not all(x in adata.obs_keys() for x in old_batch_keys):
            print('warning: dropping some of the batch_keys not present',
                  'in adata.obs, old_batch_keys=', old_batch_keys,
                  'obs_keys=', adata.obs_keys())
        new_batch_keys = [x for x in old_batch_keys if x in adata.obs_keys()]
        if new_batch_keys:
            adata.uns['batch_keys'] = new_batch_keys
        else:
            del adata.uns['batch_keys']

    if args.exclude_vgene_strings:
        # drop clones whose tcr matches any of the given substrings.
        # x[0][0]/x[1][0] presumably hold the V gene names of the alpha and
        # beta chains -- confirm against retrieve_tcrs_from_adata
        tcrs = conga.preprocess.retrieve_tcrs_from_adata(adata)
        exclude_mask = np.full((adata.shape[0],),False)
        for s in args.exclude_vgene_strings:
            mask = np.array([s in x[0][0] or s in x[1][0] for x in tcrs])
            print('exclude_vgene_strings:', s, 'num_matches:', np.sum(mask))
            exclude_mask |= mask
        adata = adata[~exclude_mask].copy()

    if args.exclude_mait_and_inkt_cells:
        # build a keep-mask that is False for MAIT/iNKT clones; only
        # supported for alpha/beta tcrs in human or mouse
        tcrs = conga.preprocess.retrieve_tcrs_from_adata(adata)
        if args.organism == 'human':
            mask = [ not (conga.tcr_scoring.is_human_mait_alpha_chain(x[0]) or
                          conga.tcr_scoring.is_human_inkt_tcr(x))
                     for x in tcrs ]
        elif args.organism == 'mouse':
            mask = [ not (conga.tcr_scoring.is_mouse_mait_alpha_chain(x[0]) or
                          conga.tcr_scoring.is_mouse_inkt_alpha_chain(x[0]))
                     for x in tcrs ]
        else:
            print('ERROR: --exclude_mait_and_inkt_cells option is only'
                  ' compatible with a/b tcrs, but organism is not "human"'
                  ' or "mouse"')
            sys.exit(1)
        print('excluding {} mait/inkt cells from dataset of size {}'\
              .format(adata.shape[0]-np.sum(mask), adata.shape[0]))
        adata = adata[mask].copy()


    if args.tenx_agbt:
        # special handling for the 10x AGBT pMHC dataset format
        conga.pmhc_scoring.shorten_pmhc_var_names(adata)

        adata.uns['pmhc_var_names'] \
            = conga.pmhc_scoring.get_tenx_agbt_pmhc_var_names(adata)

        print('pmhc_var_names:', adata.uns['pmhc_var_names'])

    if args.bad_barcodes_file:
        # one barcode per line; x[:-1] strips the trailing newline.
        # BUGFIX: mode 'rU' was removed in Python 3.11 (it raises
        # ValueError there); universal newlines are the default in text
        # mode anyway, and the handle is now closed via a context manager.
        with open(args.bad_barcodes_file) as f:
            bad_barcodes = frozenset(x[:-1] for x in f)
        bad_bc_mask = np.array( [x in bad_barcodes for x in adata.obs_names ] )
        num_bad = np.sum(bad_bc_mask)
        if num_bad:
            print('excluding {} bad barcodes found in {}'\
                  .format(num_bad, args.bad_barcodes_file))
            adata = adata[~bad_bc_mask,:].copy()
        else:
            print('WARNING:: no matched barcodes in bad_barcodes_file:',
                  args.bad_barcodes_file)


    # is the tcr-dist kPCA info present?
    assert allow_missing_kpca_file or 'X_pca_tcr' in adata.obsm_keys()
    assert 'cdr3a' in adata.obs # tcr sequence (VDJ) info (plus other obs keys)

    print(adata)

    # BUGFIX: --qc_plots is a store_true flag, so its value is False (not
    # None) when absent; the old test "args.qc_plots is None" was always
    # False, which meant QC plots were generated even without the flag.
    outfile_prefix_for_qc_plots = args.outfile_prefix if args.qc_plots \
                                  else None
    adata = conga.preprocess.filter_and_scale(
        adata,
        max_genes_per_cell = args.max_genes_per_cell,
        min_genes_per_cell = args.min_genes_per_cell,
        max_percent_mito = args.max_percent_mito,
        outfile_prefix_for_qc_plots = outfile_prefix_for_qc_plots )

    if args.filter_ribo_norm_low_cells:
        # this is sketchy
        adata = conga.devel.filter_cells_by_ribo_norm( adata )

    if args.calc_clone_pmhc_pvals:
        # do this before condensing to a single clone per cell
        # note that we are doing this after filtering out the ribo-low cells
        results_df = conga.pmhc_scoring.calc_clone_pmhc_pvals(adata)
        tsvfile = args.outfile_prefix+'_clone_pvals.tsv'
        print('making:', tsvfile)
        results_df.to_csv(tsvfile, sep='\t', index=False)

    if args.make_clone_plots:
        # need to compute cluster and umaps for these plots
        # these will be re-computed once we reduce to a single
        # cell per clonotype
        print('make_clone_plots: cluster_and_tsne_and_umap')
        adata = conga.preprocess.cluster_and_tsne_and_umap(
            adata, skip_tcr=True)

        conga.plotting.make_clone_gex_umap_plots(adata, args.outfile_prefix)

    if args.subset_to_CD4_cells or args.subset_to_CD8_cells:
        # cell-level CD4/CD8 split (the --subset_to_CD4/--subset_to_CD8
        # flags are handled separately, further down)
        adata_cd4, adata_cd8 = conga.devel.split_into_cd4_and_cd8_subsets(
            adata, verbose= True)
        if args.subset_to_CD4_cells:
            adata = adata_cd4
        else:
            adata = adata_cd8


    print('run reduce_to_single_cell_per_clone'); sys.stdout.flush()
    # condense to one representative cell per clonotype (or the clone
    # average when --average_clone_gex was given)
    adata = conga.preprocess.reduce_to_single_cell_per_clone(
        adata, average_clone_gex=args.average_clone_gex )

    if args.include_protein_features:
        # this fills X_pca_gex_only, X_pca_gex (combo), X_pca_prot
        # in the adata.obsm array
        conga.preprocess.calc_X_pca_gex_including_protein_features(
            adata, compare_distance_distributions=True)

    if args.shuffle_tcr_kpcs:
        # scramble the tcr kernel PCs across clones (testing/FDR control)
        X_pca_tcr = adata.obsm['X_pca_tcr']
        assert X_pca_tcr.shape[0] == adata.shape[0]
        reorder = np.random.permutation(X_pca_tcr.shape[0])
        adata.obsm['X_pca_tcr'] = X_pca_tcr[reorder,:]
        outlog.write('randomly permuting X_pca_tcr {}\n'\
                     .format(X_pca_tcr.shape))

    # use a higher clustering resolution (2.0) when restricting to the CD4
    # or CD8 subset, otherwise the user-supplied/default resolution
    clustering_resolution = (2.0 if (args.subset_to_CD8 or args.subset_to_CD4)
                             else args.clustering_resolution)

    print('run cluster_and_tsne_and_umap'); sys.stdout.flush()
    adata = conga.preprocess.cluster_and_tsne_and_umap(
        adata, clustering_resolution = clustering_resolution,
        clustering_method=args.clustering_method)

    ###########################################################################
else: ### restarting from a previous conga run
    ###########################################################################

    assert exists(args.restart)
    adata = sc.read_h5ad(args.restart)
    print('recover from h5ad file:', args.restart, adata )

    # older checkpoint/final files may predate the organism uns entry
    if 'organism' not in adata.uns_keys():
        assert args.organism
        adata.uns['organism'] = args.organism

    util.setup_uns_dicts(adata)

    if args.exclude_mait_and_inkt_cells and not args.exclude_gex_clusters:
        # should move this code into a helper function in conga!
        # (same MAIT/iNKT filtering as in the fresh-load branch above)
        organism = adata.uns['organism']
        tcrs = conga.preprocess.retrieve_tcrs_from_adata(adata)
        if organism == 'human':
            mask = [not (conga.tcr_scoring.is_human_mait_alpha_chain(x[0]) or
                         conga.tcr_scoring.is_human_inkt_tcr(x))
                    for x in tcrs ]
        elif organism == 'mouse':
            mask = [not (conga.tcr_scoring.is_mouse_mait_alpha_chain(x[0]) or
                         conga.tcr_scoring.is_mouse_inkt_alpha_chain(x[0]))
                    for x in tcrs ]
        else:
            print('ERROR: --exclude_mait_and_inkt_cells option is only'
                  ' compatible with a/b tcrs but organism is not "human" or'
                  ' "mouse"')
            sys.exit(1)
        print('excluding {} mait/inkt cells from dataset of size {}'\
              .format(adata.shape[0]-np.sum(mask), adata.shape[0]))
        adata = adata[mask].copy()
        # need to redo the cluster/tsne/umap
        adata = conga.preprocess.cluster_and_tsne_and_umap(
            adata, clustering_method=args.clustering_method,
            clustering_resolution=args.clustering_resolution)


    if args.shuffle_tcr_kpcs:
        # shuffle the kpcs and anything derived from them that is relevant
        #  to GvG (this is just for testing)
        # NOTE: we need to add shuffling of the neighbors if we are going
        #  to recover nbr info rather than recomputing...
        X_pca_tcr = adata.obsm['X_pca_tcr']
        assert X_pca_tcr.shape[0] == adata.shape[0]
        reorder = np.random.permutation(X_pca_tcr.shape[0])
        # apply the same permutation to the kpcs, the tcr clusters, and the
        # tcr 2D projection so they stay mutually consistent
        adata.obsm['X_pca_tcr'] = X_pca_tcr[reorder,:]
        adata.obs['clusters_tcr'] = np.array(adata.obs['clusters_tcr'])[reorder]
        adata.obsm['X_tcr_2d'] = np.array(adata.obsm['X_tcr_2d'])[reorder,:]
        print('shuffle_tcr_kpcs:: shuffled X_pca_tcr, clusters_tcr, and'
              ' X_tcr_2d')
        outlog.write(f'randomly permuting X_pca_tcr {X_pca_tcr.shape}\n')


if 'batch_keys' in adata.uns_keys():
    # sometimes if there's a single batch key the type changes from a list to
    # just the single string when we save h5ad and then reload
    batch_keys = adata.uns['batch_keys']
    # BUGFIX: the batch columns live in adata.obs (that is where they are
    # stored and validated at load time), so membership must be tested
    # against obs_keys(); the old code checked obsm_keys(), where these
    # column names never appear, making this str->list repair a no-op.
    if ( batch_keys[0] not in adata.obs_keys() and
         batch_keys in adata.obs_keys()):
        print('update adata.uns["batch_keys"] from str to list')
        adata.uns['batch_keys'] = [adata.uns['batch_keys']]


if args.exclude_gex_clusters:
    # drop all clonotypes assigned to any of the listed gex clusters
    xl = args.exclude_gex_clusters
    clusters_gex = np.array(adata.obs['clusters_gex'])
    mask = np.isin(clusters_gex, xl)
    print('exclude_gex_clusters: exclude {} cells in {} clusters: {}'\
          .format(np.sum(mask), len(xl), xl))
    sys.stdout.flush()
    adata = adata[~mask,:].copy()

    if args.exclude_mait_and_inkt_cells:
        # also filter MAIT/iNKT clones (human/mouse alpha-beta only)
        organism = adata.uns['organism']
        tcrs = conga.preprocess.retrieve_tcrs_from_adata(adata)
        if organism == 'human':
            keep = [not (conga.tcr_scoring.is_human_mait_alpha_chain(x[0]) or
                         conga.tcr_scoring.is_human_inkt_tcr(x))
                    for x in tcrs]
        elif organism == 'mouse':
            keep = [not (conga.tcr_scoring.is_mouse_mait_alpha_chain(x[0]) or
                         conga.tcr_scoring.is_mouse_inkt_alpha_chain(x[0]))
                    for x in tcrs]
        else:
            print('ERROR: --exclude_mait_and_inkt_cells option is only'
                  ' compatible with a/b tcrs but organism is not "human" or'
                  ' "mouse"')
            sys.exit(1)
        print('excluding {} mait/inkt cells from dataset of size {}'\
              .format(adata.shape[0]-np.sum(keep), adata.shape[0]))
        adata = adata[keep].copy()

    # the dataset changed, so clusters/umap have to be recomputed
    adata = conga.preprocess.cluster_and_tsne_and_umap(
        adata, clustering_method=args.clustering_method,
        clustering_resolution=args.clustering_resolution)

if args.subset_to_CD4 or args.subset_to_CD8:
    # the two flags are mutually exclusive; restrict to one compartment
    assert not (args.subset_to_CD4 and args.subset_to_CD8)
    if args.subset_to_CD4:
        which_subset = 'CD4'
    else:
        which_subset = 'CD8'
    adata = conga.preprocess.subset_to_CD4_or_CD8_clusters(
        adata, which_subset,
        use_protein_features=args.include_protein_features)

    # recompute clusters/umap on the subsetted data
    adata = conga.preprocess.cluster_and_tsne_and_umap(
        adata, clustering_method=args.clustering_method,
        clustering_resolution=args.clustering_resolution)

# this will probably not happen anymore, since we are computing these
# inside preprocess.cluster_and_tsne_and_umap if X_pca_tcr is missing.
# Special case is if we have the kernel PCS but we still want to use
# UMAP/clusters based on exact tcrdists...
have_tcr_kpcs = 'X_pca_tcr' in adata.obsm_keys()

need_to_compute_tcrdist_umap = (
    'X_tcr_2d' not in adata.obsm_keys() # missing
    or (args.use_tcrdist_umap and have_tcr_kpcs)) # recompute

need_to_compute_tcrdist_clusters = (
    'clusters_tcr' not in adata.obs_keys() # missing
    or (args.use_tcrdist_clusters and have_tcr_kpcs)) # recompute

if need_to_compute_tcrdist_umap or need_to_compute_tcrdist_clusters:
    # when only one of the two is requested, the other result is stored
    # under an alternate '*tcrdist*' key (apparently so the existing value
    # is not clobbered)
    if need_to_compute_tcrdist_umap:
        umap_key_added = 'X_tcr_2d'
    else:
        umap_key_added = 'X_tcrdist_2d'
    if need_to_compute_tcrdist_clusters:
        cluster_key_added = 'clusters_tcr'
    else:
        cluster_key_added = 'clusters_tcrdist'
    num_nbrs = 10
    conga.preprocess.calc_tcrdist_nbrs_umap_clusters_cpp(
        adata, num_nbrs,
        tmpfile_prefix=args.outfile_prefix,
        umap_key_added=umap_key_added,
        cluster_key_added=cluster_key_added)

# optionally save a checkpoint h5-formatted AnnData object
if args.checkpoint:
    checkpoint_file = args.outfile_prefix + '_checkpoint.h5ad'
    adata.write_h5ad(checkpoint_file)

###############################################################################
###
### DONE WITH I/O, now do some setup, calculate neighbor graphs, etc
###
###############################################################################


# all_nbrs is dict from nbr_frac to [nbrs_gex, nbrs_tcr]
# for nndist calculations, use a smallish nbr_frac, but not too small:
num_clones = adata.shape[0]

# adjust nbr_fracs if necessary: raise any nbr_frac that would yield fewer
# than min_nbrhood_size neighbors
if args.min_nbrhood_size is not None:
    min_nbr_frac = args.min_nbrhood_size/num_clones
    min_nbr_frac = int(1000*min_nbr_frac)/1000. # dont need all the precision
    old_nbr_fracs = list(args.nbr_fracs)
    args.nbr_fracs = sorted({max(min_nbr_frac, frac)
                             for frac in args.nbr_fracs})
    if args.nbr_fracs != old_nbr_fracs:
        print('adjusted nbr_fracs:', args.min_nbrhood_size, num_clones,
              old_nbr_fracs, args.nbr_fracs)

# smallest nbr_frac giving at least 10 nbrs (falling back to the largest)
largest_nbr_frac = max(args.nbr_fracs)
nbr_frac_for_nndists = min(frac for frac in args.nbr_fracs
                           if frac*num_clones >= 10
                           or frac == largest_nbr_frac)
outlog.write(f'nbr_frac_for_nndists: {nbr_frac_for_nndists}\n')
adata.uns['conga_stats']['nbr_frac_for_nndists'] = nbr_frac_for_nndists

# use euclidean nbrs in the tcrdist kernel-PC space unless exact tcrdist
# neighbors were requested
if args.use_exact_tcrdist_nbrs:
    obsm_tag_tcr = None
else:
    obsm_tag_tcr = 'X_pca_tcr'
all_nbrs, nndists_gex, nndists_tcr = conga.preprocess.calc_nbrs(
    adata,
    args.nbr_fracs,
    also_calc_nndists=True,
    nbr_frac_for_nndists=nbr_frac_for_nndists,
    obsm_tag_tcr=obsm_tag_tcr,
    use_exact_tcrdist_nbrs=args.use_exact_tcrdist_nbrs,
)


#
if args.analyze_junctions:
    # annotate each clone with its total number of junction insertions and
    # include that as an extra tcr score for the nbrhood analyses
    tcrs = conga.preprocess.retrieve_tcrs_from_adata(adata)
    new_tcrs = conga.tcrdist.tcr_sampler.find_alternate_alleles_for_tcrs(
        adata.uns['organism'], tcrs, verbose=False)
    junctions_df = conga.tcrdist.tcr_sampler.parse_tcr_junctions(
        adata.uns['organism'], new_tcrs)

    insert_columns = ['a_insert', 'vd_insert', 'dj_insert', 'vj_insert']
    num_inserts = sum(np.array(junctions_df[col]) for col in insert_columns)
    adata.obs['N_ins'] = num_inserts
    args.gex_nbrhood_tcr_score_names.append('N_ins')


if args.shuffle_gex_nbrs:
    # Control/diagnostic: randomly relabel the clone indices in the GEX
    # neighbor lists, scrambling any real GEX-TCR association while keeping
    # the graph's degree structure.
    reorder = np.random.permutation(num_clones)
    print('shuffling gex nbrs: num_shuffle_fixed_points=',
          np.sum(reorder==np.arange(num_clones)))
    # reorder maps from the old index to the permuted index,
    #  ie new_i = reorder[old_i]
    # Invert the permutation once (old_i = inverse[new_i]) rather than
    # calling list.index() for every clone, which was O(num_clones**2).
    inverse = np.argsort(reorder)

    for nbr_frac in args.nbr_fracs:
        old_nbrs = all_nbrs[nbr_frac][0]
        new_nbrs = []
        for new_ii in range(num_clones): # the new index
            old_ii = inverse[new_ii]
            new_nbrs.append([reorder[x] for x in old_nbrs[old_ii]])
        all_nbrs[nbr_frac] = [np.array(new_nbrs), all_nbrs[nbr_frac][1]]


# stash these in the obs array, they are used in a few places downstream
adata.obs['nndists_gex'] = nndists_gex
adata.obs['nndists_tcr'] = nndists_tcr
conga.preprocess.setup_tcr_cluster_names(adata) #stores in adata.uns


if args.verbose_nbrs:
    # Dump every neighbor graph (both modalities, all nbr_fracs) as a
    # plain-text integer matrix, one file per graph.
    for nbr_frac in args.nbr_fracs:
        nbrs_gex, nbrs_tcr = all_nbrs[nbr_frac]
        for tag, nbrs in zip(['gex', 'tcr'], [nbrs_gex, nbrs_tcr]):
            outfile = '{}_{}_nbrs_{:.3f}.txt'\
                      .format(args.outfile_prefix, tag, nbr_frac)
            np.savetxt(outfile, nbrs, fmt='%d')
            print('wrote nbrs to file:', outfile)


###############################################################################
###
### Now run the different modes of analysis requested by the user
###
###############################################################################

if (args.match_to_tcr_database and
    (args.tcr_database_tsvfile or adata.uns['organism'] == 'human')):
    # we only have a built-in database for human alpha-beta tcrs right now;
    # other organisms require a user-supplied --tcr_database_tsvfile

    # this function
    #  * returns results as a pandas DataFrame
    #  * saves the results in adata.uns['conga_results'][TCR_DB_MATCH]
    #  * writes the results to a tsvfile (since we passed outfile_prefix)
    conga.tcr_clumping.match_adata_tcrs_to_db_tcrs(
        adata,
        db_tcrs_tsvfile= args.tcr_database_tsvfile,
        outfile_prefix= args.outfile_prefix, # save results as tsvfile
        tmpfile_prefix= args.outfile_prefix,
        num_random_samples_for_bg_freqs=
            args.num_random_samples_for_tcr_matching, # long line
        adjusted_pvalue_threshold= args.pvalue_threshold_for_db_matching
    )

    # plot the database matches stashed in adata by the call above
    conga.plotting.make_tcr_db_match_plot(
        adata,
        args.outfile_prefix,
    )


if args.tcr_clumping: #########################################################
    # TCR "clumping": find receptors with surprisingly many close tcrdist
    # neighbors relative to a randomly sampled background repertoire.
    num_random_samples = args.num_random_samples_for_tcr_clumping

    if args.radii_for_tcr_clumping is None:
        radii = [24, 48, 72, 96]
    else:
        radii = args.radii_for_tcr_clumping

    # results are stored in adata.uns['conga_results'][TCR_CLUMPING]
    # and also returned by this function:
    conga.tcr_clumping.assess_tcr_clumping(
        adata,
        outfile_prefix= args.outfile_prefix, # will save results as .tsv file
        tmpfile_prefix= args.outfile_prefix,
        radii= radii,
        num_random_samples= num_random_samples,
        pvalue_threshold= args.pvalue_threshold_for_tcr_clumping,
        also_find_clumps_within_gex_clusters= args.intra_cluster_tcr_clumping,
    )

    # use the densest available neighbor graphs for the logo plots
    nbrs_gex, nbrs_tcr = all_nbrs[max(args.nbr_fracs)]

    # now call the plotting function, after results are stashed in adata
    conga.plotting.make_tcr_clumping_plots(
        adata,
        nbrs_gex,
        nbrs_tcr,
        args.outfile_prefix,
        min_cluster_size_for_logos=args.min_cluster_size_for_tcr_clumping_logos,
        pvalue_threshold_for_logos=args.pvalue_threshold_for_tcr_clumping,
        )


if args.graph_vs_graph_stats: #################################################
    # Summarize overall GEX-vs-TCR graph overlap (with num_random_repeats
    # random repeats for the background), without the full GVG analysis.
    conga.correlations.compute_graph_vs_graph_stats(
        adata, all_nbrs, num_random_repeats=100,
        outfile_prefix=args.outfile_prefix)


if args.graph_vs_graph: #######################################################
    # Graph-vs-graph analysis: look for clonotype neighborhoods whose GEX
    # and TCR neighbor sets overlap more than expected.
    # This fxn returns results as a pandas DataFrame,
    #  stashes them in adata.uns['conga_results'][GRAPH_VS_GRAPH]
    #  and saves them to a tsvfile (since we pass the outfile_prefix argument)
    conga.correlations.run_graph_vs_graph(
        adata, all_nbrs, verbose=args.verbose_nbrs,
        outfile_prefix=args.outfile_prefix)

    # take the LARGER of the two min_cluster_size thresholds
    fractional_size = int(0.5 + args.min_cluster_size_fraction * num_clones)
    min_cluster_size = max(args.min_cluster_size, fractional_size)

    nbrs_gex, nbrs_tcr = all_nbrs[max(args.nbr_fracs)]
    if args.skip_tcr_scores_in_gex_header:
        gex_header_tcr_score_names = []
    else:
        gex_header_tcr_score_names = None # ie use the defaults

    conga.plotting.make_graph_vs_graph_logos(
        adata,
        args.outfile_prefix,
        min_cluster_size,
        nbrs_gex,
        nbrs_tcr,
        gex_nbrhood_tcr_score_names=args.gex_nbrhood_tcr_score_names,
        gex_header_tcr_score_names=gex_header_tcr_score_names,
        make_gex_header = not args.skip_gex_header,
        make_gex_header_raw = not args.skip_gex_header_raw,
        make_gex_header_nbrZ = not args.skip_gex_header_nbrZ,
        include_alphadist_in_tcr_feature_logos =
            args.include_alphadist_in_tcr_feature_logos,
        show_pmhc_info_in_logos = args.show_pmhc_info_in_logos,
        logo_genes=args.gex_logo_genes,
        gex_header_genes=args.gex_header_genes,
    )

if args.graph_vs_features:
    # Correlate graph neighborhoods with individual features; the result
    # tables are computed and stored in adata.uns['conga_results'].
    conga.correlations.run_graph_vs_features(
        adata, all_nbrs, outfile_prefix=args.outfile_prefix)

    # make the plots
    if args.short_clustermaps:
        clustermap_max_type_features = 25
    else:
        clustermap_max_type_features = 50
    conga.plotting.make_graph_vs_features_plots(
        adata, all_nbrs, args.outfile_prefix,
        clustermap_max_type_features=clustermap_max_type_features,
    )


if args.graph_vs_graph and args.graph_vs_features: #########################
    # the combined summary figure needs results from both analyses
    conga.plotting.make_summary_figure(adata, args.outfile_prefix)

## some extra analyses

if args.make_tcrdist_trees: ###################################################
    # make tcrdist trees for each of the gex clusters,
    conga.plotting.make_tcrdist_trees(
        adata, args.outfile_prefix, group_by = 'clusters_gex')

    # and a tree for clones passing a conga score threshold of 10.
    conga.plotting.make_tcrdist_tree_for_conga_score_threshold(
        adata, 10., args.outfile_prefix)


if args.cluster_vs_cluster:
    # Examine interactions between the GEX-cluster and TCR-cluster
    # assignments (results written to outlog).
    tcrs = conga.preprocess.retrieve_tcrs_from_adata(adata)
    clusters_gex = np.array(adata.obs['clusters_gex'])
    clusters_tcr = np.array(adata.obs['clusters_tcr'])
    barcodes = list(adata.obs_names)
    barcode2tcr = {bc: tcr for bc, tcr in zip(barcodes, tcrs)}
    conga.devel.compute_cluster_interactions(
        clusters_gex, clusters_tcr, barcodes, barcode2tcr, outlog)

if args.plot_cluster_gene_compositions:
    # plot the gene composition of each cluster to a png file
    pngfile = args.outfile_prefix+'_cluster_gene_compositions.png'
    conga.plotting.plot_cluster_gene_compositions(adata, pngfile)


if args.find_gex_cluster_degs:
    # look at differentially expressed genes (DEGs) in gex clusters

    conga.devel.find_gex_cluster_degs(adata, args.outfile_prefix)


if args.find_hotspot_features: ################################################
    # My hacky and probably buggy first implementation of the HotSpot method:
    #
    # "Identifying Informative Gene Modules Across Modalities of
    #  Single Cell Genomics"
    # David DeTomaso, Nir Yosef
    # https://www.biorxiv.org/content/10.1101/2020.02.06.937805v1

    conga.correlations.find_hotspots_wrapper(
        adata, all_nbrs, outfile_prefix=args.outfile_prefix)

    if args.short_clustermaps:
        clustermap_max_type_features = 25
    else:
        clustermap_max_type_features = 50
    conga.plotting.make_hotspot_plots(
        adata, all_nbrs, args.outfile_prefix,
        make_raw_feature_plots=args.make_hotspot_raw_feature_plots,
        max_clones_for_dendrograms=args.max_clones_for_clustermaps,
        clustermap_max_type_features=clustermap_max_type_features,
        )

if args.find_hotspot_nbrhoods:
    # HotSpot-style analysis at the neighborhood level (devel code)
    conga.devel.find_hotspot_nbrhoods(
        adata, all_nbrs, args.outfile_prefix,
        make_hotspot_nbrhood_logos=args.make_hotspot_nbrhood_logos,
        min_cluster_size=args.min_cluster_size,
        min_cluster_size_fraction=args.min_cluster_size_fraction,
        )


if args.analyze_CD4_CD8:
    # Use the smallest nbr_frac whose neighborhoods contain more than
    # min_nbrs clones; if none qualifies, fall back to the largest one.
    min_nbrs = 10
    nbr_frac_for_plotting = next(
        (frac for frac in sorted(all_nbrs.keys())
         if frac * adata.shape[0] > min_nbrs),
        max(all_nbrs.keys()))

    conga.devel.analyze_CD4_CD8(
        adata, all_nbrs[nbr_frac_for_plotting][0], args.outfile_prefix)

if args.analyze_proteins:
    # devel analysis of protein data — presumably antibody-capture
    # features; TODO confirm against conga.devel.analyze_proteins
    conga.devel.analyze_proteins(adata, args.outfile_prefix)

if args.analyze_special_genes:
    # devel analysis of a special gene set (see conga.devel)
    conga.devel.analyze_special_genes(adata, args.outfile_prefix)


batch_bias_results = None
if args.find_batch_biases: #####################################################
    # Test whether neighborhoods / hotspot features are biased toward
    # particular experimental batches.
    pval_threshold = 0.05 # kind of arbitrary
    nbrhood_results, hotspot_results = conga.devel.find_batch_biases(
        adata,
        all_nbrs,
        pval_threshold=pval_threshold,
        exclude_batch_keys=args.exclude_batch_keys_for_biases,
    )

    if nbrhood_results.shape[0]:
        nbrhood_tsvfile = args.outfile_prefix+'_nbrhood_batch_biases.tsv'
        nbrhood_results.to_csv(nbrhood_tsvfile, sep='\t', index=False)

        # plot the biased neighborhoods using the densest neighbor graphs
        nbrs_gex, nbrs_tcr = all_nbrs[max(args.nbr_fracs)]
        conga.plotting.make_batch_bias_plots(
            adata,
            nbrhood_results,
            nbrs_gex,
            nbrs_tcr,
            args.min_cluster_size_for_batch_bias_logos,
            pval_threshold,
            args.outfile_prefix,
        )

    if hotspot_results.shape[0]:
        hotspot_tsvfile = args.outfile_prefix+'_batch_hotspots.tsv'
        hotspot_results.to_csv(hotspot_tsvfile, sep='\t', index=False)

    # saved for the clone batch clustermaps below
    batch_bias_results = (nbrhood_results, hotspot_results)



## make summary plots of top clones and their batch distributions
## also make umaps colored by batch assignment of rep cell
if 'batch_keys' in adata.uns_keys():
    conga.plotting.make_batch_colored_umaps(
        adata, args.outfile_prefix)

    # only pass along scores for analyses that were actually run
    conga_scores = adata.obs['conga_scores'] if args.graph_vs_graph else None
    tcr_clumping_pvalues = (adata.obs['tcr_clumping_pvalues']
                            if args.tcr_clumping else None)

    conga.plotting.make_clone_batch_clustermaps(
        adata, args.outfile_prefix, adata.uns['batch_keys'],
        conga_scores = conga_scores,
        tcr_clumping_pvalues = tcr_clumping_pvalues,
        batch_bias_results = batch_bias_results )


# just out of curiosity: report any in-degree bias in the neighbor graphs
conga.correlations.check_nbr_graphs_indegree_bias(all_nbrs)

if args.find_distance_correlations:
    # Per-clone correlation between GEX distances and TCR distances;
    # write out the clones with an (adjusted) pvalue below 1.
    clusters_gex = np.array(adata.obs['clusters_gex'])
    clusters_tcr = np.array(adata.obs['clusters_tcr'])
    pvalues, rvalues = conga.devel.compute_distance_correlations(adata)
    results = []
    # BUGFIX: previously zip(rvalues, pvalues) swapped the two arrays, so
    # the pval<1 filter and the pvalue_adj/rvalue columns were exchanged.
    for ii, (pval, rval) in enumerate(zip(pvalues, rvalues)):
        if pval<1:
            results.append(dict(clone_index=ii, pvalue_adj=pval, rvalue=rval,
                                gex_cluster=clusters_gex[ii],
                                tcr_cluster=clusters_tcr[ii]))
    if results:
        results_df = pd.DataFrame(results)
        outfile = args.outfile_prefix+'_distance_correlations.tsv'
        results_df.to_csv(outfile, sep='\t', index=False)

if args.find_pmhc_nbrhood_overlaps:
    # Compare pMHC scores against graph neighborhoods, for both modalities
    # at every nbr_frac, and concatenate the results into one tsv.
    agroups, bgroups = conga.preprocess.setup_tcr_groups(adata)

    pmhc_nbrhood_overlap_results = []
    for nbr_frac in args.nbr_fracs:
        nbrs_gex, nbrs_tcr = all_nbrs[nbr_frac]
        for tag, nbrs in zip(['gex', 'tcr'], [nbrs_gex, nbrs_tcr]):
            results_df = conga.pmhc_scoring.compute_pmhc_versus_nbrs(
                adata, nbrs, agroups, bgroups)
            results_df['nbr_tag'] = tag
            results_df['nbr_frac'] = nbr_frac
            pmhc_nbrhood_overlap_results.append(results_df)

    tsvfile = args.outfile_prefix+'_pmhc_versus_nbrs.tsv'
    print('making:', tsvfile)
    pd.concat(pmhc_nbrhood_overlap_results).to_csv(
        tsvfile, index=False, sep='\t')


if args.write_proj_info:
    # dump the 2D projection info to a text file
    outfile = args.outfile_prefix+'_2d_proj_info.txt'
    conga.preprocess.write_proj_info( adata, outfile )

# Save the final AnnData object.
try:
    adata.write_h5ad(args.outfile_prefix+'_final.h5ad')
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
    # not swallowed. Presumably some adata.uns['conga_results'] entries can
    # break h5ad serialization; retry without them, then restore the dict.
    print('error writing adata to file, dropping the conga_results dict')
    save = adata.uns['conga_results']
    del adata.uns['conga_results']
    adata.write_h5ad(args.outfile_prefix+'_final.h5ad')
    adata.uns['conga_results'] = save

# also save the per-clone observations as a plain tsv
adata.obs.to_csv(args.outfile_prefix+'_final_obs.tsv', sep='\t')

# generate an html report tying together the results, with the exact
# command line recorded for reproducibility
html_summary_file = args.outfile_prefix+'_results_summary.html'
conga.plotting.make_html_summary(
    adata, html_summary_file, command_string = ' '.join(sys.argv),
    title = args.outfile_prefix)

# record total runtime and finish up
outlog.write('run_conga took {:.3f} minutes\n'\
             .format((time.time()- start_time)/60))

outlog.close()
print('DONE')
